使用spark mllib 随机森林算法对文本进行多分类

您所在的位置:网站首页 随机森林 文本分类 使用spark mllib 随机森林算法对文本进行多分类

使用spark mllib 随机森林算法对文本进行多分类

2024-07-14 02:17| 来源: 网络整理| 查看: 265

1、数据准备

共20万条人工标注的文本数据,每行格式为“类别标签#k-v#原始文本”,样本如下:

1#k-v#*亮亮爱宠*波波宠物指甲钳指甲剪附送锉刀适用小型犬及猫特价
1#k-v#*顺丰包邮*宠物药品圣马利诺PowerIgG免疫力球蛋白犬猫细小病毒
1#k-v#*包邮*法国罗斯蔓草本精华宠物浴液薰衣草护色润泽香波拍套餐
1#k-v#*包邮*家朵102宠物沐浴液
1#k-v#*包邮*家朵102宠物沐浴液猫

2、分词

使用ansj分词包对文本数据进行分词,并去除停用词和单字词。代码如下:

import java.io.File; import java.io.IOException; import java.util.HashSet; import java.util.List; import java.util.Set; import org.ansj.domain.Result; import org.ansj.domain.Term; import org.ansj.splitWord.analysis.ToAnalysis; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringUtils; public class Seg{ private static Set stopwords = new HashSet(); static{ File f = new File(""); try { List lines = FileUtils.readLines(f); for(String str : lines){ stopwords.add(str); } } catch (IOException e) { e.printStackTrace(); } } public static void main(String[] args) throws IOException { File f = new File(""); File resultFile = new File(""); List lists = FileUtils.readLines(f); int count = 0; for(String str : lists){ count++; String index = str.split("#k-v#")[0]; // System.out.println(count + " " + Integer.parseInt(index)); Result res = ToAnalysis.parse(str.split("#k-v#")[1]); List terms = res.getTerms(); String wordStr = ""; for(Term t : terms){ String word = t.getName(); if(word.length()>1&&!stopwords.contains(word)){ wordStr = wordStr + " " + word; } } if(!StringUtils.isEmpty(wordStr)){ FileUtils.write(resultFile, index + "#k-v#" + wordStr + "\n" , true); } System.out.println(count); } } 3、对分词数据进行tfidf转换

这里使用的是Spark MLlib自带的TF-IDF工具(HashingTF与IDF),代码如下:

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

//case class FileRecord(index:Int,seg: String)

/** Turns segmented, labelled text into TF-IDF feature vectors saved as JSON. */
object TfIdf {
  /** Separator written by the segmentation step between the label and the words. */
  private val Separator = "#k-v#"

  /**
   * Computes TF-IDF features for every record and writes `(features, label)` rows as JSON.
   *
   * Input lines look like `<label>#k-v# word1 word2 ...`. Lines without the separator,
   * or with no text after it, are skipped instead of failing the job.
   *
   * @param args optional `[inputPath, outputPath, numFeatures]`; the defaults are the
   *             original hard-coded values `/tmp/seg_src.txt`, `/tmp/tfidf.model` and `26`
   */
  def main(args: Array[String]): Unit = {
    val inputPath = args.lift(0).getOrElse("/tmp/seg_src.txt")
    val outputPath = args.lift(1).getOrElse("/tmp/tfidf.model")
    // 26 hash buckets collide heavily for any realistic vocabulary; pass a much larger
    // value (Spark's HashingTF default is 2^18) when featurising real training data.
    val numFeatures = args.lift(2).map(_.toInt).getOrElse(26)

    val conf = new SparkConf().setAppName("TfIdfExample")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)
      val schema = StructType(
        Seq("index", "seg").map(fieldName => StructField(fieldName, StringType, nullable = true)))
      val srcRDD = sc.textFile(inputPath, 1)
        .map(_.split(Separator, 2))
        .filter(parts => parts.length == 2 && parts(1).trim.nonEmpty)
        .map(parts => Row(parts(0), parts(1).trim))
      // The "index" column holds the class label, so expose it under the name "label".
      val sentenceData = sqlContext.createDataFrame(srcRDD, schema).toDF("label", "seg")

      val tokenizer = new Tokenizer().setInputCol("seg").setOutputCol("words")
      val wordsData = tokenizer.transform(sentenceData)
      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(numFeatures)
      val featurizedData = hashingTF.transform(wordsData)
      val idfModel = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(featurizedData)
      val rescaledData = idfModel.transform(featurizedData)

      rescaledData.select("features", "label").take(3).foreach(println)
      // The sparse "indices" in this JSON are 0-based (index 0 can appear). LibSVM files are
      // 1-based, so add 1 to every index when converting; otherwise Spark's LibSVM reader
      // fails with "indices should be one-based and in ascending order".
      rescaledData.select("features", "label").write.format("json").save(outputPath)
    } finally {
      // Release the SparkContext even when a stage fails.
      sc.stop()
    }
  }
}

得到的是json数据格式,示例数据如下:

{"features":{"type":0,"size":26,"indices":[0,5,6,7,9,10,14,17,21],"values":[2.028990788466258,1.8600672974067514,1.8464729103095205,2.037399707294254,1.908861495143531,3.6260607728633083,2.0363086347259687,1.8261747092361593,2.0640809711702492]},"label":"1"}
{"features":{"type":0,"size":26,"indices":[7,8,17],"values":[4.074799414588508,2.1216332358971366,1.8261747092361593]},"label":"1"}

4、json数据转libsvm数据格式

因为下面的随机森林程序(改写自Spark自带示例)是以libsvm格式读取训练数据的,所以需要先把JSON转换成libsvm格式,代码如下:

File f = new File("D:/sogouOutput/json_feature"); File libsvmFile = new File("D:/sogouOutput/libsvm_feature"); List features = FileUtils.readLines(f); for(String str : features){ JSONObject obj = new JSONObject(str); String label = obj.getString("label"); JSONArray indexArr = obj.getJSONObject("features").getJSONArray("indices"); JSONArray valueArr = obj.getJSONObject("features").getJSONArray("values"); int length = indexArr.length(); String line = label + " "; Map indiceAndValue = new TreeMap(); for(int i=0;i 4 distinct values are treated as continuous. val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(26).fit(data) // Split the data into training and test sets (30% held out for testing) val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a RandomForest model. val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10) // Convert indexed labels back to original labels. val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels) // Chain indexers and forest in a Pipeline val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) // Train model. This also runs the indexers. val model = pipeline.fit(trainingData) // Make predictions. val predictions = model.transform(testData) // Select example rows to display. 
predictions.select("predictedLabel", "label", "features").show(5) // Select (prediction, true label) and compute test error val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("precision") val accuracy = evaluator.evaluate(predictions) println("Test Error = " + (1.0 - accuracy)) val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] println("Learned classification forest model:\n" + rfModel.toDebugString) // $example off$ sc.stop() } }

在运行过程中,val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data) 这句代码会报错:

Caused by: java.lang.IllegalArgumentException: requirement failed: indices should be one-based and in ascending order

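经查找,原因在于libsvm格式要求特征索引从1开始。Spark在解析libsvm文件时会把每个索引减1,转换为从0开始的内部索引(见下面的源代码)。而TF-IDF输出的特征索引本身就是从0开始的(如上面示例中的索引0),减1后变成-1,无法通过“索引严格递增”的校验,因此报错。解决方法是在把JSON转换为libsvm格式时,将每个特征索引加1。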

private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
  // A record is "<label> <index>:<value> <index>:<value> ...", separated by single spaces.
  val tokens = line.split(' ')
  val label = tokens.head.toDouble
  val pairs = tokens.tail.filter(_.nonEmpty).map { token =>
    val parts = token.split(':')
    // On disk the indices are 1-based; shift each one to 0-based.
    (parts(0).toInt - 1, parts(1).toDouble)
  }
  val indices = pairs.map(_._1)
  val values = pairs.map(_._2)
  // Indices must be strictly increasing. Seeding the fold with -1 also rejects an on-disk
  // index of 0, because it becomes -1 after the shift.
  indices.foldLeft(-1) { (previous, current) =>
    require(current > previous, s"indices should be one-based and in ascending order;" +
      s""" found current=$current, previous=$previous; line="$line"""")
    current
  }
  (label, indices.toArray, values.toArray)
}


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3